{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Converting labels from png to nii file\n",
    "\n",
    "\n",
    "### Overview\n",
    "\n",
    "This is the first step of the data preparation.\n",
    "\n",
    "Input: ground truth labels in `.png` format\n",
    "\n",
    "Output: labels in compressed NIfTI (`.nii.gz`) format, one file per patient id (`label_<id>.nii.gz`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Once deleted, variables cannot be recovered. Proceed (y/[n])? y\n",
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# -f clears the namespace without the interactive y/n prompt, so Run All is not blocked\n",
    "%reset -f\n",
    "# pick up edits to local modules (niftiio) without restarting the kernel\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import glob\n",
    "\n",
    "import numpy as np\n",
    "import PIL\n",
    "# `import PIL` alone does not load the Image submodule; import it explicitly so\n",
    "# PIL.Image.open does not depend on another package having imported it first\n",
    "import PIL.Image\n",
    "import matplotlib.pyplot as plt\n",
    "import SimpleITK as sitk\n",
    "import sys\n",
    "# make the project-local dataloaders folder importable (path is relative to this notebook)\n",
    "sys.path.insert(0, '../../dataloaders/')\n",
    "import niftiio as nio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example of a ground-truth file name. The number after the last '-' (00001 here)\n",
    "# is what the conversion loop below uses to order the slices of a scan.\n",
    "example = \"./MR/1/T2SPIR/Ground/IMG-0002-00001.png\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "### search for scan ids\n",
    "# Every sub-directory of ./MR/ is treated as one scan id. Plain files that may sit\n",
    "# in that folder (e.g. .DS_Store) are skipped, since they cannot contain any slices.\n",
    "ids = [name for name in os.listdir(\"./MR/\") if os.path.isdir(os.path.join(\"./MR/\", name))]\n",
    "# folder the label volumes are written to\n",
    "OUT_DIR = './niis/T2SPIR/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['37',\n",
       " '3',\n",
       " '15',\n",
       " '34',\n",
       " '33',\n",
       " '39',\n",
       " '20',\n",
       " '10',\n",
       " '22',\n",
       " '8',\n",
       " '31',\n",
       " '2',\n",
       " '36',\n",
       " '5',\n",
       " '13',\n",
       " '19',\n",
       " '21',\n",
       " '1',\n",
       " '38',\n",
       " '32']"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sanity check: show the scan ids that were found\n",
    "ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image with id 37 has been saved!\n",
      "image with id 3 has been saved!\n",
      "image with id 15 has been saved!\n",
      "image with id 34 has been saved!\n",
      "image with id 33 has been saved!\n",
      "image with id 39 has been saved!\n",
      "image with id 20 has been saved!\n",
      "image with id 10 has been saved!\n",
      "image with id 22 has been saved!\n",
      "image with id 8 has been saved!\n",
      "image with id 31 has been saved!\n",
      "image with id 2 has been saved!\n",
      "image with id 36 has been saved!\n",
      "image with id 5 has been saved!\n",
      "image with id 13 has been saved!\n",
      "image with id 19 has been saved!\n",
      "image with id 21 has been saved!\n",
      "image with id 1 has been saved!\n",
      "image with id 38 has been saved!\n",
      "image with id 32 has been saved!\n"
     ]
    }
   ],
   "source": [
    "#### Write them to nii files for the ease of loading in future\n",
    "# For every scan id: stack the ground-truth png slices into one volume, renumber\n",
    "# the pixel values as consecutive integer labels and write label_<id>.nii.gz.\n",
    "for curr_id in ids:\n",
    "    pngs = glob.glob(f'./MR/{curr_id}/T2SPIR/Ground/*.png')\n",
    "    if not pngs:\n",
    "        # np.stack raises on an empty list (e.g. a stray non-scan entry in ./MR/)\n",
    "        print(f'no ground-truth png found for id {curr_id}, skipping')\n",
    "        continue\n",
    "    # order the slices by the trailing number of the file name (IMG-0002-00001.png -> 1)\n",
    "    pngs = sorted(pngs, key = lambda x: int(os.path.basename(x).split(\"-\")[-1].split(\".png\")[0]))\n",
    "    buffer = []\n",
    "\n",
    "    for fid in pngs:\n",
    "        # PIL opens files lazily; copy the pixels out and close the file immediately\n",
    "        # instead of keeping one open handle per slice\n",
    "        with PIL.Image.open(fid) as slice_img:\n",
    "            buffer.append(np.array(slice_img))\n",
    "\n",
    "    # slices become axis 0 of the volume\n",
    "    vol = np.stack(buffer, axis = 0)\n",
    "    # flip correction (reverses axis 1, i.e. the row order of every slice)\n",
    "    vol = np.flip(vol, axis = 1).copy()\n",
    "    # remap values to 0..K-1 in ascending order of the original values; np.unique is\n",
    "    # already sorted. In-place is safe for non-negative values because the new value\n",
    "    # never exceeds the old one, so an already remapped voxel cannot match a later old value.\n",
    "    # NOTE(review): the mapping is derived per volume, so a volume in which one class is\n",
    "    # absent gets different label numbers than the others -- confirm this cannot happen.\n",
    "    for new_val, old_val in enumerate(np.unique(vol)):\n",
    "        vol[vol == old_val] = new_val\n",
    "\n",
    "    # get reference\n",
    "    ref_img = f'./niis/T2SPIR/image_{curr_id}.nii.gz'\n",
    "    img_o = sitk.ReadImage(ref_img)\n",
    "    # project helper; presumably copies the geometry of the reference image -- TODO confirm\n",
    "    vol_o = nio.np2itk(img=vol, ref_obj=img_o)\n",
    "    sitk.WriteImage(vol_o, os.path.join(OUT_DIR, f'label_{curr_id}.nii.gz'))\n",
    "    print(f'image with id {curr_id} has been saved!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
